CIFAR-10 is a dataset of 60,000 32x32 colour images in 10 classes, with 6,000 images per class. There are 50,000 training images and 10,000 test images in total (i.e. 5,000 training and 1,000 test images per class).
The categories are: airplane, automobile, bird, cat, deer, dog, frog, horse, ship and truck.
The classes are completely mutually exclusive. There is no overlap between automobiles and trucks. "Automobile" includes sedans, SUVs, things of that sort. "Truck" includes only big trucks. Neither includes pickup trucks.
This ensures that each class is isolated from the rest.
Since the task is image generation, generative adversarial networks (GANs) are a natural choice, and several GAN architectures could be employed, such as the deep convolutional GAN (DCGAN).
Due to time constraints, I was only able to try out DCGAN.
# Data Manipulation Dependencies
import glob
import numpy as np
import pandas as pd
import tensorflow as tf
from math import floor
# Graphing Dependencies
from matplotlib import pyplot as plt
import seaborn as sns
from tensorflow.keras.utils import plot_model
# Image Manipulation Dependencies
import imageio
from PIL import Image, ImageEnhance
from skimage.transform import resize
from tensorflow.keras.preprocessing.image import ImageDataGenerator
# Image Clustering Dependencies
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
# Deep Learning Dependencies
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.models import Model, Sequential, load_model
from tensorflow.keras.layers import *
from tensorflow.keras.optimizers import *
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.callbacks import EarlyStopping
# GAN Evaluation Dependencies
from scipy.linalg import sqrtm
from tensorflow.keras.applications.inception_v3 import InceptionV3, preprocess_input
# Miscellaneous Dependencies
import os
from abc import ABC, abstractmethod
from pathlib import Path
from warnings import filterwarnings
filterwarnings(action='ignore')
# Define Constants
# Root directory under which every GAN keeps its artifacts ('<Class>_dcgan/...').
ARTIFACT_STORE = './gans/'
# CIFAR-10 class names; a name's position in this tuple is its integer label.
CIFAR10_CLASSES = (
    'Airplane',
    'Automobile',
    'Bird',
    'Cat',
    'Deer',
    'Dog',
    'Frog',
    'Horse',
    'Ship',
    'Truck',
)
# Converts image data from 0..255 to -1..1
def squeeze_images(img):
    """Map pixel intensities from the 0..255 range onto -1..1.

    The input is cast to a new float64 array (the caller's array is not
    modified), divided by 127.5 and shifted down by one, so 0 -> -1.0 and
    255 -> 1.0.
    """
    scaled = img.astype(np.float64) / 127.5
    return scaled - 1.0
# Converts image data from -1..1 to 0..255
def unsqueeze_images(img):
    """Map pixel intensities from the -1..1 range back onto 0..255 as uint8.

    Inverse of ``squeeze_images``. Unlike the previous implementation, the
    caller's array is not modified in place.

    Values are rounded to the nearest integer rather than truncated. A
    squeeze/unsqueeze round trip can land just below an integer
    (e.g. 2.9999999), which truncation would turn into 2.

    Values are also clipped to 0..255, so out-of-range inputs saturate
    instead of overflowing when cast to uint8.

    Args:
        img: array-like of values nominally in -1..1 (e.g. tanh generator output).

    Returns:
        np.ndarray of dtype uint8 with the same shape as ``img``.
    """
    # np.asarray(..., float64) + arithmetic yields a new array; the input is untouched.
    pixels = (np.asarray(img, dtype=np.float64) + 1.0) * 127.5
    return np.clip(np.rint(pixels), 0, 255).astype(np.uint8)
def load_data():
    """Return all CIFAR-10 images and labels, training split first, then test split.

    Returns:
        (X, y): the image arrays and label arrays from ``cifar10.load_data()``,
        each with train and test concatenated along axis 0.
    """
    train_split, test_split = cifar10.load_data()
    images = np.concatenate([train_split[0], test_split[0]], axis=0)
    labels = np.concatenate([train_split[1], test_split[1]], axis=0)
    return images, labels
def partition_data(class_):
    """Return only the CIFAR-10 images (train + test) labelled ``class_``.

    Args:
        class_: one of the names in ``CIFAR10_CLASSES``; ``tuple.index``
            raises ValueError for an unknown name.
    """
    images, labels = load_data()
    label_id = CIFAR10_CLASSES.index(class_)
    keep = labels.ravel() == label_id
    return images[keep]
def sub_partition_data(class_, n_partitions=10, n_samples=8, drop=None):
    """Cluster one class's images and either preview the clusters or filter them.

    Images are flattened, projected to 10 PCA components and grouped into
    ``n_partitions`` clusters with K-Means. Both are constructed with
    ``random_state=0``.

    Args:
        class_: CIFAR-10 class name (see ``partition_data``).
        n_partitions: number of K-Means clusters.
        n_samples: images shown per cluster in preview mode.
        drop: cluster ids to discard. ``None`` (the default) plots a preview
            grid instead so suitable ids can be chosen. Any other value,
            including an empty list or a single int, is used as a filter.

    Returns:
        The images whose cluster id is not in ``drop``, or ``None`` in preview mode.
    """
    X = partition_data(class_)
    feat = X.reshape(X.shape[0], -1)
    X_reduced = PCA(n_components=10, random_state=0).fit_transform(feat)
    mask = KMeans(n_clusters=n_partitions, random_state=0).fit_predict(X_reduced)
    # `is not None` so drop=[] / drop=0 filter instead of silently plotting and returning None.
    if drop is not None:
        return X[~np.isin(mask, drop)]
    # squeeze=False keeps `ax` 2-D even when n_partitions or n_samples is 1.
    _, ax = plt.subplots(n_partitions, n_samples, figsize=(12, 8), squeeze=False)
    for r in range(n_partitions):
        group = X[mask == r]
        ax[r, 0].set_title(f'Group {r} - Count: {group.shape[0]}')
        for c in range(n_samples):
            # Show the first n_samples members of cluster r. The old index
            # `r * n_samples + c` skipped into the cluster and raised IndexError
            # for small clusters.
            if c < group.shape[0]:
                ax[r, c].imshow(group[c])
            ax[r, c].axis('off')
    plt.show()
def sharpen_images(imgs, degree=2.0):
    """Sharpen a single image or a 4-D batch with PIL's ``ImageEnhance.Sharpness``.

    Args:
        imgs: one image array, or a 4-D batch with images stacked on axis 0.
            Every image must be accepted by ``Image.fromarray`` (e.g. uint8).
        degree: factor passed to ``Sharpness.enhance``.

    Returns:
        np.ndarray with the same layout as ``imgs`` (batch stays a batch).
    """
    def _sharpen(arr):
        enhanced = ImageEnhance.Sharpness(Image.fromarray(arr)).enhance(degree)
        return np.array(enhanced)

    if len(imgs.shape) != 4:
        return _sharpen(imgs)
    return np.stack([_sharpen(arr) for arr in imgs], axis=0)
def preview_sharpening_effect(sharpening_degree=2.0, n_samples=5):
    """Plot random CIFAR-10 images (left column) beside their sharpened versions (right).

    Args:
        sharpening_degree: factor forwarded to ``sharpen_images``.
        n_samples: number of rows / random images (drawn with replacement).
    """
    X, _ = load_data()
    samples = X[np.random.randint(0, X.shape[0], n_samples)]
    # squeeze=False keeps `ax` 2-D; with n_samples=1 the old ax[r, 0] raised.
    _, ax = plt.subplots(n_samples, 2, squeeze=False)
    for r, sample in enumerate(samples):
        ax[r, 0].imshow(sample)
        ax[r, 0].axis('off')
        ax[r, 1].imshow(sharpen_images(sample, degree=sharpening_degree))
        ax[r, 1].axis('off')
    plt.show()


preview_sharpening_effect()
def generate_latent_points(n_samples, latent_dim=100):
    """Draw generator inputs from a standard normal distribution.

    Returns:
        float64 array of shape (n_samples, latent_dim), drawn from NumPy's
        global random state.
    """
    return np.random.standard_normal(size=(n_samples, latent_dim))
def generate_fake_samples(generator, batch_size, latent_dim=100):
    """Return ``batch_size`` images predicted by ``generator`` from fresh latent noise.

    ``latent_dim`` must match the size of the generator's input.
    """
    return generator.predict(generate_latent_points(batch_size, latent_dim))
def generate_real_samples(batch_size, X):
    """Pick ``batch_size`` rows of ``X`` uniformly at random, with replacement."""
    picked = np.random.randint(low=0, high=X.shape[0], size=batch_size)
    return X[picked]
def smoothen_labels(y, range_=(0.8, 1.1)):
    """Replace hard labels with noisy ones.

    Every entry is shifted by ``low - 1`` and receives independent uniform
    noise of width ``high - low``, so a label of 1 becomes a value in
    ``[low, high)``. Noise is drawn from NumPy's global random state.
    """
    low, high = range_
    noise = np.random.random(y.shape) * (high - low)
    return y - 1 + low + noise
def generate_gif(class_, rows=5, cols=6, n_frames=20):
    """Animate training progress: one sample grid per saved generator, stitched into a GIF.

    Every ``<ARTIFACT_STORE><class_>_dcgan/generator/<epoch>.h5`` model is
    loaded in numeric epoch order and fed the same latent vectors, so frames
    show the same samples evolving. Each grid is repeated ``n_frames`` times
    to slow playback. The GIF is written to ``../images/gif/<class_>.gif``.

    Args:
        class_: CIFAR-10 class name whose generators are animated.
        rows, cols: layout of each sample grid.
        n_frames: number of GIF frames each checkpoint occupies.
    """
    model_glob = ARTIFACT_STORE + class_ + '_dcgan/generator/*.h5'
    # Sort numerically on the file stem ("20.h5" -> 20). Path handles both '/'
    # and '\\'; the old split('\\') only worked on Windows.
    model_files = sorted(glob.glob(model_glob), key=lambda m: int(Path(m).stem))
    random_latent_vectors = generate_latent_points(n_samples=rows * cols)
    frame_files = []
    try:
        for i, model_file in enumerate(model_files):
            epoch = int(Path(model_file).stem)
            model = load_model(model_file, {'ConvTransposeBlock': ConvTransposeBlock}, compile=False)
            X = unsqueeze_images(model.predict(random_latent_vectors))
            fig, ax = plt.subplots(rows, cols, squeeze=False)
            for r in range(rows):
                for c in range(cols):
                    ax[r, c].imshow(X[r * cols + c])
                    ax[r, c].axis('off')
            # Title with the real epoch: checkpoints were not always 20 epochs apart.
            fig.suptitle(f'{epoch} epochs')
            tmp_file = f'{i}.png'
            fig.savefig(tmp_file)
            plt.close(fig)
            frame_files.append(tmp_file)
        Path('../images/gif').mkdir(parents=True, exist_ok=True)
        with imageio.get_writer(f'../images/gif/{class_}.gif', mode='I') as writer:
            for frame_file in frame_files:
                image = imageio.imread(frame_file)
                for _ in range(n_frames):
                    writer.append_data(image)
    finally:
        # Remove temporary frames even if loading or rendering fails midway.
        for frame_file in frame_files:
            if os.path.exists(frame_file):
                os.remove(frame_file)
def generate_images(class_, version, rows=4, cols=5, n_grids=5):
    """Save ``n_grids`` PNG grids of samples from one saved generator.

    Loads ``<ARTIFACT_STORE><class_>_dcgan/generator/<version>.h5`` and writes
    ``../images/png/<class_>_<i>.png`` for i in 0..n_grids-1.

    Args:
        class_: CIFAR-10 class name.
        version: epoch number of the generator file to load.
        rows, cols: grid layout (defaults reproduce the original 4x5 grid).
        n_grids: number of grids / files to write (default 5, as before).
    """
    model_file = f'{ARTIFACT_STORE}{class_}_dcgan/generator/{version}.h5'
    model = load_model(model_file, {'ConvTransposeBlock': ConvTransposeBlock}, compile=False)
    Path('../images/png').mkdir(parents=True, exist_ok=True)
    for i in range(n_grids):
        # Generate exactly one image per cell (previously 25 for a 20-cell grid).
        X = unsqueeze_images(generate_fake_samples(generator=model, batch_size=rows * cols))
        fig, ax = plt.subplots(rows, cols, squeeze=False)
        for r in range(rows):
            for c in range(cols):
                ax[r, c].imshow(X[r * cols + c])
                ax[r, c].axis('off')
        fig.savefig(f'../images/png/{class_}_{i}.png')
        plt.close(fig)
def inspect_class_distributions():
    """Bar-plot the number of CIFAR-10 images (train + test) per class."""
    _, labels = load_data()
    label_series = pd.Series(data=labels.flatten())
    per_label = label_series.value_counts().sort_index()
    counts = pd.DataFrame(data={
        'Class': CIFAR10_CLASSES,
        'Count': per_label.values,
    })
    bar_ax = sns.barplot(data=counts, x='Class', y='Count')
    bar_ax.set_xticklabels(bar_ax.get_xticklabels(), rotation=40)


inspect_class_distributions()
def plot_samples(rows, cols):
    """Show a rows x cols grid of distinct random CIFAR-10 images, titled with their class."""
    X, y = load_data()
    # Sample without replacement so the preview never shows the same image twice.
    mask = np.random.choice(X.shape[0], rows * cols, replace=False)
    samples = X[mask]
    labels = y.flatten()[mask]
    # squeeze=False keeps `ax` 2-D for single-row or single-column grids.
    _, ax = plt.subplots(rows, cols, figsize=(10, 8), squeeze=False)
    for r in range(rows):
        for c in range(cols):
            k = r * cols + c
            ax[r, c].imshow(samples[k])
            ax[r, c].axis('off')
            ax[r, c].set_title(CIFAR10_CLASSES[labels[k]])


plot_samples(6, 8)
class BaseGAN(ABC):
    """Abstract GAN that owns its models, sample grids and training checkpoints.

    Subclasses implement ``build_generator``, ``build_discriminator``,
    ``build_gan`` and ``train``. Artifacts are stored under
    ``artifact_directory``:

    * ``checkpoints/`` - ``tf.train.CheckpointManager`` state (10 most recent kept)
    * ``generated_images/`` - ``<name>_<epoch>.png`` sample grids
    * ``generator/``, ``discriminator/``, ``gan/`` - ``<epoch>.h5`` model files
    """

    def __init__(self, name, artifact_directory, input_shape, latent_dim):
        """Build the three models and resume from existing artifacts if any.

        Args:
            name: prefix of sample-grid filenames; also used to find them again.
            artifact_directory: artifact root. Sub-paths are built by plain
                string concatenation, so it must end with a path separator.
            input_shape: shape of a single image.
            latent_dim: length of the generator's latent vector.

        ``self.warm_start`` becomes the highest epoch among existing sample
        grids (0 if there are none). The latest TF checkpoint, if present, is
        restored into the generator, discriminator and combined model.
        """
        self.input_shape = input_shape
        self.latent_dim = latent_dim
        self.name = name
        self.artifact_directory = artifact_directory
        # build_gan() may rely on the two models built just before it (DCGAN's does).
        self.generator = self.build_generator()
        self.discriminator = self.build_discriminator()
        self.gan = self.build_gan()
        Path(artifact_directory + 'checkpoints').mkdir(parents=True, exist_ok=True)
        Path(artifact_directory + 'generated_images').mkdir(parents=True, exist_ok=True)
        checkpoint_dir = f'{artifact_directory}checkpoints'
        checkpoint = tf.train.Checkpoint(
            generator=self.generator,
            discriminator=self.discriminator,
            gan=self.gan,
        )
        prev_generated_images = [
            im for im in os.listdir(f'{artifact_directory}generated_images') if name in im
        ]
        # Grids are saved as '<name>_<epoch>.png'; the largest epoch is where training stopped.
        self.warm_start = max(
            (int(im.replace(f'{name}_', '').split('.')[0]) for im in prev_generated_images),
            default=0,
        )
        self.checkpoint_manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=10)
        if self.checkpoint_manager.latest_checkpoint:
            checkpoint.restore(self.checkpoint_manager.latest_checkpoint)
            print('Resuming training from last checkpoint...')

    @abstractmethod
    def build_discriminator(self):
        """Return the (compiled) discriminator model."""

    @abstractmethod
    def build_generator(self):
        """Return the generator model mapping latent vectors to images."""

    @abstractmethod
    def build_gan(self):
        """Return the combined generator -> discriminator model used to train the generator."""

    def save_images(self, epoch, rows=5, cols=5):
        """Save a rows x cols grid of fresh samples to generated_images/<name>_<epoch>.png."""
        noise = np.random.randn(rows * cols, self.latent_dim)
        gen_imgs = unsqueeze_images(self.generator.predict(noise))
        # squeeze=False keeps `axs` 2-D even for a single row or column.
        fig, axs = plt.subplots(rows, cols, squeeze=False)
        for i in range(rows):
            for j in range(cols):
                axs[i, j].imshow(gen_imgs[i * cols + j])
                axs[i, j].axis('off')
        fig.savefig(f'{self.artifact_directory}generated_images/{self.name}_{epoch}.png')
        # Close this figure explicitly rather than whichever one is "current".
        plt.close(fig)

    def checkpoint(self, epoch, save_gen=True, save_disc=True, save_gan=True, **kwargs):
        """Persist progress for ``epoch``: a sample grid, selected .h5 models and a TF checkpoint.

        ``kwargs`` are forwarded to ``save_images`` (e.g. ``rows``/``cols``).
        """
        self.save_images(epoch=epoch, **kwargs)
        targets = (
            (save_gen, self.generator, 'generator'),
            (save_disc, self.discriminator, 'discriminator'),
            (save_gan, self.gan, 'gan'),
        )
        for enabled, model, sub_dir in targets:
            if enabled:
                # __init__ only creates checkpoints/ and generated_images/; ensure this one exists.
                Path(f'{self.artifact_directory}{sub_dir}').mkdir(parents=True, exist_ok=True)
                model.save(f'{self.artifact_directory}{sub_dir}/{epoch}.h5')
        self.checkpoint_manager.save()

    def summarize_performance(self, X, epoch, n_samples, y_real=None):
        """Print discriminator accuracy and loss on ``n_samples`` real and fake images.

        Args:
            X: pool of real images to sample from.
            epoch: current epoch (not included in the printed output).
            n_samples: number of real images and of fake images evaluated.
            y_real: labels for the real images; defaults to all ones.
        """
        if y_real is None:
            y_real = tf.ones(shape=(n_samples, 1))
        y_fake = tf.zeros(shape=(n_samples, 1))
        X_real = generate_real_samples(n_samples, X)
        d_loss_real, acc_real = self.discriminator.evaluate(X_real, y_real, verbose=0)
        X_fake = generate_fake_samples(self.generator, n_samples, self.latent_dim)
        d_loss_fake, acc_fake = self.discriminator.evaluate(X_fake, y_fake, verbose=0)
        print(f'\nDiscriminator Accuracy: Real ({acc_real * 100.0}%) Fake ({acc_fake * 100.0}%)')
        print(f'Discriminator Loss: Real ({d_loss_real}) Fake ({d_loss_fake})')

    @abstractmethod
    def train(self):
        """Run the training loop."""
class ConvBlock(tf.keras.layers.Layer):
    """Downsampling block: 4x4 Conv2D (stride 2, 'same' padding) followed by LeakyReLU(0.2).

    Halves the spatial size and outputs ``filters`` channels.
    """

    def __init__(self, filters=32, init=None, **kwargs):
        """
        Args:
            filters: number of output channels.
            init: kernel initializer for the convolution (name, initializer or None).
            **kwargs: standard Keras ``Layer`` arguments (name, dtype, ...).
        """
        super().__init__(**kwargs)
        self.filters = filters
        self.init = init
        self.conv = Conv2D(filters=filters, kernel_size=4, strides=2, padding='same', kernel_initializer=init)
        self.act = LeakyReLU(alpha=0.2)

    def call(self, x):
        return self.act(self.conv(x))

    def get_config(self):
        """Return every constructor argument so the layer can be rebuilt from a saved model.

        ``init`` was previously missing, so reloaded layers silently fell back
        to the default initializer. It is stored serialized because it may be
        an initializer object. Configs saved without it still load, with
        ``init=None``.
        """
        config = super().get_config()
        config.update({
            "filters": self.filters,
            "init": tf.keras.initializers.serialize(tf.keras.initializers.get(self.init)),
        })
        return config
class ConvTransposeBlock(tf.keras.layers.Layer):
    """Upsampling block: 4x4 Conv2DTranspose (stride 2, 'same' padding) followed by LeakyReLU(0.2).

    Doubles the spatial size and outputs ``filters`` channels.
    """

    def __init__(self, filters=32, init=None, **kwargs):
        """
        Args:
            filters: number of output channels.
            init: kernel initializer for the transposed convolution (name, initializer or None).
            **kwargs: standard Keras ``Layer`` arguments (name, dtype, ...).
        """
        super().__init__(**kwargs)
        self.filters = filters
        self.init = init
        self.conv_T = Conv2DTranspose(filters=filters, kernel_size=4, strides=2, padding='same', kernel_initializer=init)
        self.act = LeakyReLU(alpha=0.2)

    def call(self, x):
        return self.act(self.conv_T(x))

    def get_config(self):
        """Return every constructor argument so the layer can be rebuilt from a saved model.

        ``init`` was previously missing from the config. It is stored
        serialized because it may be an initializer object. Generator files
        saved without it still load (via ``load_model`` with custom objects),
        with ``init=None``.
        """
        config = super().get_config()
        config.update({
            "filters": self.filters,
            "init": tf.keras.initializers.serialize(tf.keras.initializers.get(self.init)),
        })
        return config
class DCGAN(BaseGAN):
    """Deep convolutional GAN trained on a single CIFAR-10 class.

    The training set is the class's images minus the K-Means clusters listed
    in ``drop`` (see ``sub_partition_data``). The images are sharpened,
    rescaled to -1..1 and augmented on the fly by ``ImageDataGenerator``.
    """

    def __init__(self, class_, drop, input_shape=(32, 32, 3), latent_dim=100):
        """
        Args:
            class_: CIFAR-10 class name; artifacts go to ``<ARTIFACT_STORE><class_>_dcgan/``.
            drop: cluster ids removed by ``sub_partition_data``.
            input_shape: image shape seen by the discriminator.
            latent_dim: size of the generator's latent input.
        """
        super().__init__(name=class_ + '_dcgan',
                         artifact_directory=ARTIFACT_STORE + class_ + '_dcgan/',
                         input_shape=input_shape,
                         latent_dim=latent_dim)
        self.data_gen = ImageDataGenerator(
            rotation_range=20,
            width_shift_range=0.1,
            height_shift_range=0.1,
            horizontal_flip=True)
        self.X = sub_partition_data(class_=class_, drop=drop)
        self.X = sharpen_images(self.X, degree=2.0)
        self.X = squeeze_images(self.X)
        self.data_gen.fit(self.X)

    def build_discriminator(self):
        """Stack four stride-2 ConvBlocks, dropout and a sigmoid unit that outputs P(real).

        Compiled with Adam(2e-4, beta_1=0.5) and binary cross-entropy.
        """
        disc = Sequential(layers=[
            Input(shape=self.input_shape),
            ConvBlock(filters=64, init='he_normal'),
            ConvBlock(filters=128, init='he_normal'),
            ConvBlock(filters=128, init='he_normal'),
            ConvBlock(filters=256, init='he_normal'),
            Flatten(),
            Dropout(rate=0.4),
            Dense(units=1, activation='sigmoid')
        ])
        disc.compile(optimizer=Adam(learning_rate=2e-4, beta_1=0.5),
                     loss='binary_crossentropy',
                     metrics=['accuracy'])
        return disc

    def build_generator(self):
        """Map a latent vector to a 32x32x3 image in -1..1.

        Dense projection to 4x4x128, then three ConvTransposeBlocks (4 -> 32),
        then a tanh Conv2D. Not compiled here; the generator is trained
        through ``build_gan``'s combined model.
        """
        gen = Sequential(layers=[
            Input(shape=(self.latent_dim,)),
            Dense(units=4 * 4 * 128),
            Reshape(target_shape=(4, 4, 128)),
            ConvTransposeBlock(filters=128),
            ConvTransposeBlock(filters=128),
            ConvTransposeBlock(filters=128),
            Dropout(rate=0.2),
            Conv2D(filters=3, kernel_size=3, padding='same', activation='tanh')
        ])
        return gen

    def build_gan(self):
        """Chain generator -> frozen discriminator and compile it for generator updates."""
        # Frozen only after the discriminator was compiled on its own; this relies
        # on Keras honouring the trainable flag captured at compile time, so the
        # discriminator's own train_on_batch still updates it. Verify on the
        # installed Keras version.
        self.discriminator.trainable = False
        gan = Sequential(layers=[
            self.generator,
            self.discriminator
        ])
        gan.compile(
            optimizer=Adam(learning_rate=2e-4, beta_1=0.5),
            loss='binary_crossentropy')
        return gan

    def train(self, epochs=1000, batch_size=128, n_batches=100, checkpoint_interval=20, summary_interval=1):
        """Alternate discriminator and generator updates.

        Each epoch draws ``n_batches`` augmented real batches. For every batch:

        1. The discriminator trains on the real images (label 1).
        2. It then trains on ``batch_size`` generated images (label 0).
        3. The generator trains through ``self.gan`` on ``2 * batch_size``
           latent vectors with smoothed "real" labels.

        Training resumes after ``self.warm_start``. A checkpoint is written
        every ``checkpoint_interval`` epochs, and discriminator metrics are
        printed every ``summary_interval`` epochs.
        """
        from itertools import islice

        y_fake = np.zeros((batch_size, 1))
        # Labels shaped (n, 1) to match the discriminator output. The old
        # tf.ones(n, 1) passed 1 as the *dtype* argument and produced shape (n,).
        y_all_real = np.ones((self.X.shape[0], 1))
        for epoch in range(1 + self.warm_start, epochs + 1):
            batches = self.data_gen.flow(self.X, y_all_real, batch_size=batch_size)
            # islice stops after n_batches without drawing (and augmenting) an extra batch.
            for i, (X_real, y_real) in enumerate(islice(batches, n_batches), start=1):
                print(f'\rEpoch {epoch}/{epochs}: Batch {i}/{n_batches}', end='')
                self.discriminator.train_on_batch(X_real, y_real)
                # Pass latent_dim explicitly: the helpers default to 100, which broke
                # any DCGAN constructed with a different latent_dim.
                X_fake = generate_fake_samples(self.generator, batch_size, self.latent_dim)
                self.discriminator.train_on_batch(X_fake, y_fake)
                X_noise = generate_latent_points(batch_size * 2, self.latent_dim)
                self.gan.train_on_batch(X_noise, smoothen_labels(np.ones((batch_size * 2, 1))))
            if epoch % checkpoint_interval == 0:
                self.checkpoint(epoch=epoch)
            if epoch % summary_interval == 0:
                self.summarize_performance(self.X, epoch, batch_size)
# Build a DCGAN for 'Airplane' (dropping K-Means cluster 0) to inspect the
# architectures. Note: the constructor also loads and filters the class data,
# creates artifact directories under ARTIFACT_STORE and restores the latest
# checkpoint if one exists.
test_gan = DCGAN('Airplane', drop=[0])
test_gan.generator.summary()
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense (Dense) (None, 2048) 206848
reshape (Reshape) (None, 4, 4, 128) 0
conv_transpose_block (ConvT (None, 8, 8, 128) 262272
ransposeBlock)
conv_transpose_block_1 (Con (None, 16, 16, 128) 262272
vTransposeBlock)
conv_transpose_block_2 (Con (None, 32, 32, 128) 262272
vTransposeBlock)
dropout (Dropout) (None, 32, 32, 128) 0
conv2d (Conv2D) (None, 32, 32, 3) 3459
=================================================================
Total params: 997,123
Trainable params: 997,123
Non-trainable params: 0
_________________________________________________________________
# plot_model(test_gan.generator, to_file='generator_plot.png', show_shapes=True, show_layer_names=True)

# Print the discriminator architecture. It reports 0 trainable params because
# build_gan() sets discriminator.trainable = False after the model was compiled.
test_gan.discriminator.summary()
Model: "sequential_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv_block (ConvBlock) (None, 16, 16, 64) 3136
conv_block_1 (ConvBlock) (None, 8, 8, 128) 131200
conv_block_2 (ConvBlock) (None, 4, 4, 128) 262272
conv_block_3 (ConvBlock) (None, 2, 2, 256) 524544
flatten (Flatten) (None, 1024) 0
dropout_1 (Dropout) (None, 1024) 0
dense_1 (Dense) (None, 1) 1025
=================================================================
Total params: 922,177
Trainable params: 0
Non-trainable params: 922,177
_________________________________________________________________
# plot_model(test_gan.discriminator, to_file='discriminator_plot.png', show_shapes=True, show_layer_names=True)

# Per-class workflow. sub_partition_data(class_) with no `drop` only plots the
# class's K-Means clusters so unwanted ones can be identified. The commented
# lines after each call record the run used afterwards:
#   - train a DCGAN without those clusters,
#   - save sample grids from a chosen generator epoch,
#   - build a training-progress GIF.
sub_partition_data('Airplane')
# airplane_gan = DCGAN(class_='Airplane', drop=[5, 6, 7])
# airplane_gan.train(260)
# generate_images('Airplane', 200)
# generate_gif('Airplane')

sub_partition_data('Automobile')
# automobile_gan = DCGAN(class_='Automobile', drop=[0, 8])
# automobile_gan.train(200)
# generate_images('Automobile', 200)
# generate_gif('Automobile')

sub_partition_data('Bird')
# bird_gan = DCGAN(class_='Bird', drop=[0, 2, 7, 8, 9])
# bird_gan.train(200)
# generate_images('Bird', 200)
# generate_gif('Bird')

sub_partition_data('Cat')
# cat_gan = DCGAN(class_='Cat', drop=[0, 1, 3, 4, 6, 7])
# cat_gan.train(260, checkpoint_interval=5)
# generate_images('Cat', 230)
# generate_gif('Cat')

sub_partition_data('Deer')
# deer_gan = DCGAN(class_='Deer', drop=[1, 2, 6, 8, 9])
# deer_gan.train(240, checkpoint_interval=5)
# generate_images('Deer', 175)
# generate_gif('Deer')

sub_partition_data('Dog')
# dog_gan = DCGAN(class_='Dog', drop=[4])
# dog_gan.train(240, checkpoint_interval=5)
# generate_images('Dog', 210)
# generate_gif('Dog')

sub_partition_data('Frog')
# frog_gan = DCGAN(class_='Frog', drop=[0, 1, 8])
# frog_gan.train(220, checkpoint_interval=10)
# generate_images('Frog', 200)
# generate_gif('Frog')

sub_partition_data('Horse')
# horse_gan = DCGAN(class_='Horse', drop=[2, 3, 4])
# horse_gan.train(240, checkpoint_interval=10)
# generate_images('Horse', 220)
# generate_gif('Horse')

sub_partition_data('Ship')
# ship_gan = DCGAN(class_='Ship', drop=[0, 1, 6, 9])
# ship_gan.train(260, checkpoint_interval=20)
# generate_images('Ship', 260)
# generate_gif('Ship')

sub_partition_data('Truck')
# truck_gan = DCGAN(class_='Truck', drop=[0, 3, 4])
# truck_gan.train(260, checkpoint_interval=20)
# generate_images('Truck', 260)
# generate_gif('Truck')

# Epoch of the saved generator (generator/<epoch>.h5) evaluated for each class.
best_versions = dict(
    Airplane=200,
    Automobile=200,
    Bird=200,
    Cat=230,
    Deer=175,
    Dog=210,
    Frog=210,
    Horse=220,
    Ship=260,
    Truck=260,
)
def scale_images(images, new_shape):
    """Resize every image to ``new_shape`` with ``skimage.transform.resize`` (order=0).

    Returns the resized images stacked into a single array.
    """
    return np.asarray([resize(image, new_shape, 0) for image in images])
def calculate_inception_score(images, n_split=10, eps=1e-16, model=None):
    """Compute the Inception Score of ``images``, averaged over ``n_split`` splits.

    For each split, IS = exp(mean_x KL(p(y|x) || p(y))). Here p(y|x) are the
    classifier's predicted class probabilities and p(y) is their mean over
    the split. Images beyond ``n_split * (len(images) // n_split)`` are
    ignored.

    Args:
        images: batch already preprocessed for the classifier. For the
            default InceptionV3 that means 299x299x3 passed through
            ``preprocess_input``.
        n_split: number of equally sized splits.
        eps: added inside the logarithms to avoid log(0).
        model: object whose ``predict`` returns class probabilities. Defaults
            to a newly built ``InceptionV3()``; pass one in to avoid
            rebuilding the network on every call.

    Returns:
        Mean Inception Score over the splits.

    Raises:
        ValueError: if there are fewer images than splits. Every split would
            be empty and the score NaN.
    """
    n_part = images.shape[0] // n_split
    if n_part == 0:
        raise ValueError(f'need at least n_split={n_split} images, got {images.shape[0]}')
    if model is None:
        model = InceptionV3()
    yhat = model.predict(images)
    scores = []
    for i in range(n_split):
        p_yx = yhat[i * n_part:(i + 1) * n_part]
        p_y = p_yx.mean(axis=0, keepdims=True)
        kl_d = p_yx * (np.log(p_yx + eps) - np.log(p_y + eps))
        scores.append(np.exp(kl_d.sum(axis=1).mean()))
    return np.mean(scores)
def get_inception_scores(best_versions, batch_size=1000, verbose=True):
    """Compare the Inception Score of real CIFAR-10 images with generated ones.

    For every ``class -> generator epoch`` pair in ``best_versions``:

    - ``batch_size`` real images of that class are sampled with replacement;
    - ``batch_size`` images are produced by the chosen generator.

    Both pools are upscaled to 299x299, passed through ``preprocess_input``
    and scored with ``calculate_inception_score``.

    Returns:
        (is_real, is_fake) Inception Scores.
    """
    def log(message):
        if verbose:
            print(message)

    real_batches, fake_batches = [], []
    for class_, version in best_versions.items():
        generator_model = load_model(
            f'{ARTIFACT_STORE}{class_}_dcgan/generator/{version}.h5',
            {'ConvTransposeBlock': ConvTransposeBlock},
            compile=False,
        )
        class_images = partition_data(class_)
        real_batches.append(generate_real_samples(batch_size=batch_size, X=class_images).astype('float32'))
        generated = generate_fake_samples(generator=generator_model, batch_size=batch_size)
        fake_batches.append(unsqueeze_images(generated).astype('float32'))
    real_images = np.concatenate(real_batches, axis=0)
    fake_images = np.concatenate(fake_batches, axis=0)
    log(f'Original Shapes - Real: {real_images.shape} Fake: {fake_images.shape}')
    real_images = preprocess_input(scale_images(real_images, (299, 299, 3)))
    fake_images = preprocess_input(scale_images(fake_images, (299, 299, 3)))
    log(f'Shapes after Processing - Real: {real_images.shape} Fake: {fake_images.shape}')
    is_real = calculate_inception_score(real_images)
    log(f'Inception Score (real images): {is_real}')
    is_fake = calculate_inception_score(fake_images)
    log(f'Inception Score (fake images): {is_fake}')
    return is_real, is_fake


get_inception_scores(best_versions, batch_size=1000)
Original Shapes - Real: (10000, 32, 32, 3) Fake: (10000, 32, 32, 3) Shapes after Processing - Real: (10000, 299, 299, 3) Fake: (10000, 299, 299, 3) Inception Score (real images): 5.108252048492432 Inception Score (fake images): 4.2629218101501465
(5.108252, 4.262922)
def calculate_fid(model, images1, images2, eps=1e-6):
    """Compute the Frechet Inception Distance between two image sets.

    Both sets are embedded with ``model.predict``. Each embedding cloud is
    summarised by its mean ``mu`` and covariance ``S``, and

        FID = ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1 S2)^{1/2}).

    Args:
        model: feature extractor exposing ``predict``.
        images1, images2: preprocessed image batches.
        eps: diagonal offset added to both covariances when the square root
            of their product is not finite. With fewer samples than feature
            dimensions the covariances are singular and ``sqrtm`` can return
            inf/nan.

    Returns:
        The FID (lower means the sets are more similar).
    """
    act1 = model.predict(images1)
    act2 = model.predict(images2)
    mu1, sigma1 = act1.mean(axis=0), np.cov(act1, rowvar=False)
    mu2, sigma2 = act2.mean(axis=0), np.cov(act2, rowvar=False)
    ssdiff = np.sum((mu1 - mu2) ** 2.0)
    covmean = sqrtm(sigma1.dot(sigma2))
    if not np.isfinite(covmean).all():
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = sqrtm((sigma1 + offset).dot(sigma2 + offset))
    # Numerical error can leave a small imaginary component; keep the real part.
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
def get_fid_scores(class_, version, verbose=True, n_samples=100, inception_model=None):
    """Compute the FID of one generator against real images of its class, plus a real baseline.

    Args:
        class_: CIFAR-10 class name.
        version: epoch of the generator file to load.
        verbose: print shapes and scores.
        n_samples: images drawn per set (previously fixed at 100). FID is
            noisy and biased for small samples; larger values are more stable
            but cost more Inception passes.
        inception_model: pooled InceptionV3 feature extractor to reuse across
            calls; built on demand when omitted.

    Returns:
        (fid_real, fid_fake):
        - ``fid_real`` compares two independently drawn real samples, giving
          the noise floor at this sample size. Comparing a sample with itself,
          as before, is always ~0.
        - ``fid_fake`` compares real images with generated ones.
    """
    model_file = f'{ARTIFACT_STORE}{class_}_dcgan/generator/{version}.h5'
    generator_model = load_model(model_file, {'ConvTransposeBlock': ConvTransposeBlock}, compile=False)
    if inception_model is None:
        inception_model = InceptionV3(include_top=False, pooling='avg', input_shape=(299, 299, 3))
    X = partition_data(class_)
    real_images = generate_real_samples(batch_size=n_samples, X=X).astype('float32')
    reference_images = generate_real_samples(batch_size=n_samples, X=X).astype('float32')
    fake_images = unsqueeze_images(generate_fake_samples(generator=generator_model, batch_size=n_samples)).astype('float32')
    if verbose:
        print(f'Original Shapes - Real: {real_images.shape} Fake: {fake_images.shape}')
    real_images = preprocess_input(scale_images(real_images, (299, 299, 3)))
    reference_images = preprocess_input(scale_images(reference_images, (299, 299, 3)))
    fake_images = preprocess_input(scale_images(fake_images, (299, 299, 3)))
    if verbose:
        print(f'Shapes after Processing - Real: {real_images.shape} Fake: {fake_images.shape}')
    fid_real = calculate_fid(inception_model, real_images, reference_images)
    if verbose:
        print(f'FID (real images): {fid_real}')
    fid_fake = calculate_fid(inception_model, real_images, fake_images)
    if verbose:
        print(f'FID (fake images): {fid_fake}')
    return fid_real, fid_fake
def plot_fid_scores(best_versions):
    """Bar-plot each class's generated-vs-real FID for a ``{class: generator epoch}`` mapping."""
    class_names = list(best_versions)
    fid_scores = [
        get_fid_scores(class_, version, verbose=False)[1]
        for class_, version in best_versions.items()
    ]
    sns.barplot(x=class_names, y=fid_scores)
    plt.show()


plot_fid_scores(best_versions=best_versions)